import os
import gc
import re
import math
import numpy as np
import pandas as pd
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
from tqdm import tqdm

# —— Land-use table proportion fix (largest-remainder method) ——
# Region to rewrite: from the "| 用地类型" header cell up to and including an
# existing "计算：" line, or up to the end of the text when there is none.
_TABLE_REGION_RE = re.compile(r"(\|\s*用地类型.*?)(?:计算：[^\n]*\n?|\Z)", flags=re.S)
# One data row of the form "| <land-use type> | <number>% |".  The number must
# be a well-formed decimal so that float() below cannot raise.
_TABLE_ROW_RE = re.compile(r"\|\s*(.+?)\s*\|\s*(\d+(?:\.\d*)?|\.\d+)%\s*\|")


def fix_table_proportions(text: str) -> str:
    """Rescale the land-use table in *text* so its percentages sum to 100.

    The table is located by its "| 用地类型" header.  Every row of the form
    "| type | N% |" is rescaled proportionally to integer percentages using
    the largest-remainder method, and a "计算：a%+b%+... = 100%" line is
    written after the table.

    Args:
        text: Generated scheme text that may contain the land-use table.

    Returns:
        The text with the table rewritten, or *text* unchanged when no table
        header is found, the table region has fewer than three lines, or no
        row carries a positive percentage.

    Notes:
        * An existing "计算：" line is replaced rather than kept next to the
          new one.
        * Lines in the table region that are not data rows (notes, rows with
          unparsable values) are kept after the calculation line instead of
          being dropped.
        * If the line after the header is itself a data row (no separator
          line was emitted), it is normalised together with the other rows.
    """
    match = _TABLE_REGION_RE.search(text)
    if not match:
        return text

    lines = match.group(1).strip().splitlines()
    if len(lines) < 3:
        print("⚠️ 表格行数不足，跳过比例修正。")
        return text

    header, *body = lines
    table_head = [header]
    # The second line is normally the Markdown separator; keep it verbatim
    # unless it is actually the first data row.
    if not _TABLE_ROW_RE.match(body[0].strip()):
        table_head.append(body[0])
        body = body[1:]

    types, vals, leftovers = [], [], []
    for line in body:
        row = _TABLE_ROW_RE.match(line.strip())
        if row:
            types.append(row.group(1))
            vals.append(float(row.group(2)))
        elif line.strip():
            # Not a data row: preserve it rather than silently deleting it.
            leftovers.append(line)

    total = sum(vals)
    if total <= 0:
        return text

    # Largest-remainder apportionment: floor every share, then hand the
    # missing points to the rows with the largest fractional parts.
    raw_percents = [v / total * 100 for v in vals]
    shares = [math.floor(r) for r in raw_percents]
    # max() is a defensive guard: a negative slice bound would bump almost
    # every row if rounding ever pushed the floored sum above 100.
    shortfall = max(100 - sum(shares), 0)
    by_remainder = sorted(
        range(len(shares)),
        key=lambda i: raw_percents[i] - shares[i],
        reverse=True,
    )
    for i in by_remainder[:shortfall]:
        shares[i] += 1

    new_rows = [f"| {t:<22} | {p:>3d}% |" for t, p in zip(types, shares)]
    calc_line = "计算：" + "+".join(f"{p}%" for p in shares) + " = 100%"

    replacement = "\n".join(table_head + new_rows) + "\n\n" + calc_line + "\n\n"
    if leftovers:
        replacement += "\n".join(leftovers) + "\n\n"

    # Splice by position so only the matched region is rewritten.
    return text[:match.start()] + replacement + text[match.end():]

# Device: use the GPU when one is available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Paths: base model + the fused DPO LoRA adapter directory.
base_model_path = r"D:\chatglm3-models\chatglm3-6b"  # base model
lora_adapter_dir = r"D:\chongxinxuexi\pythonProject1\dpo_model\fusion_from_expert"  # fused LoRA adapter directory

print("🚀 开始加载模型和分词器...", flush=True)
# Prefer the tokenizer stored with the adapter; fall back to the base model's
# tokenizer when the adapter directory does not provide a loadable one.
try:
    tokenizer = AutoTokenizer.from_pretrained(lora_adapter_dir, trust_remote_code=True, use_fast=True)
except Exception as exc:
    # The fallback is deliberate best-effort, but report why it happened so a
    # broken adapter directory does not go unnoticed.
    print(f"⚠️ 适配器目录中的分词器加载失败（{exc!r}），改用底模分词器。", flush=True)
    tokenizer = AutoTokenizer.from_pretrained(base_model_path, trust_remote_code=True, use_fast=True)

# This script only generates text. Decoder-only models continue from the last
# input position, so any padding must go on the left; right padding is a
# training-time setting. The calls below tokenize one sequence at a time, so
# no padding is applied today.
# NOTE(review): the ChatGLM tokenizer's remote code may itself require left
# padding — confirm against the tokenizer implementation in use.
tokenizer.padding_side = "left"
if tokenizer.pad_token is None:
    # Without a pad token, reuse EOS so generate() has a valid pad_token_id.
    tokenizer.pad_token = tokenizer.eos_token

# Load the base model first, then wrap it with the LoRA adapter.
# Half precision on GPU; library default dtype on CPU.
base_model = AutoModelForCausalLM.from_pretrained(
    base_model_path, trust_remote_code=True, torch_dtype=torch.float16 if torch.cuda.is_available() else None
).to(device)
model = PeftModel.from_pretrained(base_model, lora_adapter_dir).to(device)
model.eval()  # inference mode: disables dropout
print("✅ 模型加载完成！\n", flush=True)

# Input workbook (path kept as originally configured); it must provide a
# "prompt" column.
data_path = r"D:\chongxinxuexi\pythonProject1\data\最终评估数据1.xlsx"
df = pd.read_excel(data_path, engine="openpyxl")
has_prompt_column = "prompt" in df.columns
if not has_prompt_column:
    raise ValueError("❌ Excel 文件中必须包含名为 'prompt' 的一列")

# —— Results are written below the fused-adapter directory ——
output_path = r"D:\chongxinxuexi\pythonProject1\dpo_model\fusion_from_expert\生成方案\融合dpo模型生成评估数据的方案.xlsx"
output_dir = os.path.dirname(output_path)
os.makedirs(output_dir, exist_ok=True)  # create the output folder if it is missing

print(f"📄 读取到 {len(df)} 条 Prompt，开始生成方案...\n", flush=True)

# Generation settings: deterministic decoding via beam search (no sampling).
# Marker appended to every prompt; the scheme body is expected to follow it.
split_token = "### 以下是改造方案正文："
generate_kwargs = dict(
    max_new_tokens=800,
    do_sample=False,
    num_beams=4,
    num_return_sequences=1,
    eos_token_id=tokenizer.eos_token_id,
    pad_token_id=tokenizer.pad_token_id,
)

# One scheme is generated per strategy instruction listed here.
strategy_prompts = [
    "请生成一个更新方案，逻辑完整，突出重点，符合城市更新政策与技术规范。",
]

# Main generation loop: one output row per prompt, one "方案N" column per
# strategy instruction.
results = []
# fillna("") first: astype(str) alone would turn empty Excel cells into the
# literal prompt "nan" and spend a full beam-search generation on it.
prompts = df["prompt"].fillna("").astype(str)
for idx, prompt in enumerate(tqdm(prompts, desc="生成进度"), start=1):
    prompt = prompt.strip()
    if not prompt:
        tqdm.write(f"⚠️ Prompt #{idx} 为空，已跳过。")
        continue
    tqdm.write(f"\n🟢 Prompt #{idx}: {prompt}\n")
    entry = {"prompt": prompt}

    for i, strat in enumerate(strategy_prompts):
        prompt_text = (
            f"{prompt}\n\n"
            f"{strat}\n\n"
            "请严格按照以下结构输出：一、规划思路（填写总体定位、核心目标与空间引导逻辑）；二、空间布局优化思路（条目式列出关键优化思路，不使用第一人称）；三、主要更新举措（条目式列出关键更新策略与实施路径，不使用第一人称）；四、土地利用规划（使用纯文本表格填写用地类型及比例，总和必须为100%，表格头为| 用地类型 | 比例 |，校验：总和=100%）；请完整输出上述内容，并自行检查。\n\n"
            f"{split_token}\n\n"
        )

        inputs = tokenizer(prompt_text, return_tensors="pt", truncation=True, max_length=2048)
        inputs = {k: v.to(device) for k, v in inputs.items()}
        prompt_len = inputs["input_ids"].shape[1]

        with torch.no_grad():
            output_ids = model.generate(**inputs, **generate_kwargs)

        # For a causal LM, generate() returns the prompt tokens followed by the
        # continuation. Decode only the continuation: locating the marker in
        # the decoded full text fails when the prompt was truncated (the
        # marker sits at its very end) or when decoding does not reproduce the
        # marker character-for-character.
        new_tokens = output_ids[0][prompt_len:]
        text_out = tokenizer.decode(new_tokens, skip_special_tokens=True, clean_up_tokenization_spaces=True)

        # Safety net: if the model echoed the marker itself, keep what follows.
        scheme = text_out.split(split_token, 1)[1].strip() if split_token in text_out else text_out.strip()

        # Remove Markdown code-fence markers but keep the enclosed text. The
        # land-use table is the most likely thing to be fenced, and deleting
        # whole fenced blocks would delete the table with them.
        scheme = re.sub(r"^[ \t]*```[A-Za-z0-9_+-]*[ \t]*$\n?", "", scheme, flags=re.M)
        scheme = scheme.replace("```", "").strip()

        # Normalise the land-use proportions so they sum to 100%.
        scheme = fix_table_proportions(scheme)
        # tqdm.write keeps the progress bar intact, unlike a bare print().
        tqdm.write("✅ 方案生成完成。\n")

        label = f"方案{i+1}"
        tqdm.write(f"—— {label} ——")
        tqdm.write(scheme + "\n")
        entry[label] = scheme

    results.append(entry)

    # Checkpoint every 10 prompts so a late failure does not lose everything.
    if idx % 10 == 0:
        pd.DataFrame(results).to_excel(output_path, index=False, engine="openpyxl")
        tqdm.write(f"✅ 已保存前 {idx} 条结果至：{output_path}")

    if torch.cuda.is_available():
        # Collect Python garbage first so the tensors it frees can actually be
        # returned by empty_cache().
        gc.collect()
        torch.cuda.empty_cache()

# Final save
pd.DataFrame(results).to_excel(output_path, index=False, engine="openpyxl")
tqdm.write(f"\n🎉 所有方案已生成完毕，结果保存至：{output_path}")
